#-----------------------------------------------------------------------------
#
# Replication code for Callaway and Li (2022)
#
# Authors: Brantly Callaway and Tong Li
# Date:    2/7/2023
#
#-----------------------------------------------------------------------------


#-----------------------------------------------------------------------------
# load external packages
#-----------------------------------------------------------------------------
library(ggplot2)
library(tidyr)
library(MatchIt)
library(cobalt)
library(WeightIt)
library(ppe)
library(pte)
library(DRDID)

#-----------------------------------------------------------------------------
# load data
#-----------------------------------------------------------------------------
# main dataset (provides `southeast_covid_data`)
load("southeast_covid_data.RData")

# subset of data for travel results (provides `travel_subset`)
load("travel_subset.RData")

# shorten names
dta <- southeast_covid_data
rm(southeast_covid_data)
dta2 <- travel_subset
rm(travel_subset)

# script parameters
treated_state <- "Tennessee"
# NOTE(review): `comp_states` is not referenced in the code visible in this
# file; presumably the loaded data is already restricted to these states --
# confirm.
comp_states <- sort(c("Alabama", "Arkansas", "Georgia", "North Carolina", "Kentucky", "Mississippi"))
start_date <- as.Date("2020-03-16")
end_date <- as.Date("2020-05-10")
policy_date <- as.Date("2020-04-01")
# `group` is the policy date measured in days since `start_date`. Counties in
# the treated state get this value; all other counties get 0.
group <- as.numeric(policy_date - start_date)
dta$group <- group*(dta$state==treated_state)
dta2$group <- group*(dta2$state==treated_state)

# most results use the following formula for the covariates
# (dcases.pc is an optional extra covariate, currently commented out)
xformla <- ~ lag_cum_tests.pc + lag_cum_cases.pc + lpop + lag_cum_deaths.pc #+ dcases.pc
matching_formula <- xformla

# code for producing Table 1, summary statistics and covariate balance,
# evaluated on the day before the policy date
this_data <- subset(dta, date==policy_date-1)

# The propensity-score model below uses a binary treatment indicator `D`,
# which this script does not otherwise create. Build it from `group` only when
# the loaded data does not already supply it, so an existing `D` is untouched.
if (!("D" %in% names(this_data))) {
  this_data$D <- as.numeric(this_data$group > 0)
}

wres <- weightit(BMisc::toformula("D",  BMisc::rhs.vars(xformla)),
                 data=this_data,
                 estimand="ATT",
                 method="ps",
                 stabilize=TRUE)

# print underlying data for Table 1; explicit print() so the table is also
# shown when the script is run with source(), which does not auto-print
print(bal.tab(wres, stats = c("m", "v"), thresholds = c(m = .05),
              data=this_data,
              un=TRUE,
              disp="means"))

#-----------------------------------------------------------------------------
# helper functions
#-----------------------------------------------------------------------------
# Function to make the event-study plots reported in the paper, drawn on a
# calendar-date x axis.
#
# Arguments:
#   pte_res        - object returned by pte(); drawn with ggpte()
#   ylabel, xlabel - optional axis labels; NULL leaves ggpte()'s labels as-is
#   ylimits        - optional length-2 numeric y limits, applied with ylim()
#                    (note: ylim() drops data outside the limits)
#   policy, start, end - policy date and first/last dates of the sample
#                    window; default to the script-level policy_date,
#                    start_date and end_date
#
# Prints the plot and returns it invisibly so callers can save or modify it.
make_pandemic_plot <- function(pte_res, ylabel=NULL, xlabel=NULL, ylimits=NULL,
                               policy=policy_date, start=start_date,
                               end=end_date) {
  p <- ggpte(pte_res)
  # `e` in ggpte()'s plot data holds event time; adding it to the policy date
  # converts it to calendar dates (treating one period as one day). The date
  # scale below is the only x scale added, so ggplot2 does not report a
  # replaced scale.
  p$data$e <- as.Date(policy + p$data$e)
  p <- p + scale_x_date(breaks=seq(start+5, end, by=7), date_labels="%b %d",
                        limits=c(start+1, end-1))
  if (!is.null(ylimits)) p <- p + ylim(ylimits)
  if (!is.null(ylabel)) p <- p + ylab(ylabel)
  if (!is.null(xlabel)) p <- p + xlab(xlabel)
  p <- p + theme(legend.position="none")
  print(p)
  invisible(p)
}


#-----------------------------------------------------------------------------
# code for producing main results
#-----------------------------------------------------------------------------

# Outcome variables. The cumulative series are used by default; swap in the
# commented alternatives to use 7 day moving averages instead.
tests_name <- "cum_tests.pc"
# tests_name <- "tests.pc.ma"
cases_name <- "cum_cases.pc"
# cases_name <- "cases.pc.ma"

# Default estimation setup: the weighted_reg_attgt group-time estimator with
# "ps" weighting, run on the full main and travel datasets.
weighting_method <- "ps"
attgt_fun <- weighted_reg_attgt
this_data2 <- dta2
this_data <- dta


#-----------------------------------------------------------------------------
# additional code for matching
#   - uncomment below code to use matching
#-----------------------------------------------------------------------------
# matching_formula <- xformla
# weighting_method <- NULL
# this_data <- subset(dta, fips != 47157) # drops Memphis
# this_data2 <- subset(dta2, fips != 47157) # drops Memphis (travel subset)
# attgt_fun <- matching_attgt
# xformla <- ~1




# Effect of the policy on the testing outcome (`tests_name`), with pointwise
# (cband = FALSE) intervals at alp = 0.1. ret_imputation = TRUE is kept;
# presumably it is what fills the extra_gt_returns$actual / $imputation
# values read by compute_bounds() -- confirm in pte().
tests_res <- pte(
  yname = tests_name,
  gname = "group",
  tname = "time.period",
  idname = "fips",
  data = this_data,
  subset_fun = two_by_two_subset,
  setup_pte_fun = setup_pte_basic,
  attgt_fun = attgt_fun,
  cband = FALSE,
  alp = 0.1,
  weighting_method = weighting_method,
  ret_imputation = TRUE,
  matching_formula = matching_formula,
  xformla = xformla
)

make_pandemic_plot(
  tests_res,
  ylabel = "Cumulative Tests per 1000 People",
  xlabel = ""
)

# Effect of the policy on the confirmed-cases outcome (`cases_name`), using
# the same estimator settings as the testing results.
cases_res <- pte(
  yname = cases_name,
  gname = "group",
  tname = "time.period",
  idname = "fips",
  data = this_data,
  subset_fun = two_by_two_subset,
  setup_pte_fun = setup_pte_basic,
  attgt_fun = attgt_fun,
  cband = FALSE,
  alp = 0.1,
  weighting_method = weighting_method,
  ret_imputation = TRUE,
  matching_formula = matching_formula,
  xformla = xformla
)

make_pandemic_plot(
  cases_res,
  ylabel = "Cumulative Confirmed Cases per 1000 People",
  xlabel = ""
)


# code for producing bounds
#
# Bounds on the policy's effect on cumulative infections (per 1000 people) on
# a single date, built from the cases_res and tests_res fits.
#
# Arguments:
#   this_date - a Date in the sample window. It is mapped to position
#               this_date - start_date - 1 of the per-period result lists, so
#               start_date + 2 is the first usable date.
#   false.neg - rate used to inflate confirmed cases, R / (1 - false.neg);
#               defaults to the 0.25 used in the paper's results
#
# Returns a list with this_date, lower_bound, upper_bound (both per 1000
# people) and upper_ci, a one-sided normal upper limit for the upper bound.
compute_bounds <- function(this_date, false.neg = 0.25) {
  this_idx <- as.numeric(this_date - start_date) - 1

  cases_gt <- cases_res$att_gt$extra_gt_returns[[this_idx]]$extra_gt_returns
  tests_gt <- tests_res$att_gt$extra_gt_returns[[this_idx]]$extra_gt_returns

  # Outcomes are per 1000 people; divide by 1000 to get per-person rates.
  # tn.* use the `actual` values, cf.* the `imputation` (counterfactual) ones.
  tn.R <- cases_gt$actual / 1000
  tn.T <- tests_gt$actual / 1000
  cf.R <- cases_gt$imputation / 1000
  cf.T <- tests_gt$imputation / 1000

  # confirmed cases inflated for false negatives
  tn.gam <- tn.R/(1-false.neg)
  cf.gam <- cf.R/(1-false.neg)

  # bounds on treatment effects. A looser upper bound,
  # (tn.gam - cf.gam) + tn.gam * (1-tn.T)/tn.T, is not part of the result.
  te.Upper.C <- tn.gam - cf.gam
  te.Lower.C <- te.Upper.C - cf.gam * (1-cf.T)/cf.T

  # standard error of the cases event-study estimate at this position,
  # rescaled like the upper bound; qnorm(.9) gives a one-sided 90% limit
  upper_se <- cases_res$event_study$se.egt[this_idx]/(1-false.neg)
  upper_ci <- mean(te.Upper.C)*1000 + qnorm(.9)*upper_se

  list(this_date=this_date,
       lower_bound=mean(te.Lower.C)*1000,
       upper_bound=mean(te.Upper.C)*1000,
       upper_ci=upper_ci)
}

# Bounds for every date in the window. Dates start at start_date + 2 because
# compute_bounds() indexes position this_date - start_date - 1, which must
# be at least 1.
date_range <- seq(start_date+2, end_date-1, by=1)
bounds_list <- lapply(date_range, compute_bounds)
bounds_res <- do.call(rbind.data.frame, bounds_list)
# reassign the dates so the column is guaranteed to be of class Date
bounds_res$this_date <- date_range

# long format: one row per (date, bound type)
plot_df <- tidyr::pivot_longer(bounds_res, cols=c(upper_bound, lower_bound))
plot_df$name <- factor(plot_df$name,
                       levels=c("upper_bound", "lower_bound"),
                       labels=c("Upper Bound", "Lower Bound"))
# the confidence limit is only drawn in the upper-bound panel
plot_df$upper_ci[plot_df$name=="Lower Bound"] <- NA
plot_df$post <- as.factor(plot_df$this_date >= policy_date)

bounds_plot <- ggplot(plot_df, aes(x=this_date, y=value)) +
  facet_grid(name ~ ., scales="free_y") +
  geom_line(aes(color=post)) +
  geom_point(aes(color=post), size=1.1) +
  # na.rm: the lower-bound panel's upper_ci is all NA by construction
  geom_line(aes(y=upper_ci), linetype="dashed", na.rm=TRUE) +
  ylab("Cumulative Infections per 1000 People") +
  xlab("") +
  scale_x_date(breaks=seq(start_date+5,end_date,by=7), date_labels="%b %d", limits=c(start_date+1, end_date-1)) +
  theme_bw() +
  theme(legend.position="none")
# explicit print() so the figure is also drawn when the script is source()d
print(bounds_plot)


#-----------------------------------------------------------------------------
# additional code for producing policy effects on deaths 
#-----------------------------------------------------------------------------

# Outcome: cumulative deaths by default; the commented line switches to the
# 7 day moving average series.
deaths_name <- "cum_deaths.pc"
# deaths_name <- "deaths.pc.ma"

# Effect of the policy on deaths, using the same estimator settings as the
# tests and cases results.
deaths_res <- pte(
  yname = deaths_name,
  gname = "group",
  tname = "time.period",
  idname = "fips",
  data = this_data,
  subset_fun = two_by_two_subset,
  setup_pte_fun = setup_pte_basic,
  attgt_fun = attgt_fun,
  cband = FALSE,
  alp = 0.1,
  weighting_method = weighting_method,
  ret_imputation = TRUE,
  matching_formula = matching_formula,
  xformla = xformla
)

make_pandemic_plot(
  deaths_res,
  ylabel = "Cumulative Deaths per 1000 People",
  xlabel = ""
)


#-----------------------------------------------------------------------------
# additional code for producing policy effects on travel 
#-----------------------------------------------------------------------------
# Effect of the policy on trips to workplaces, estimated on the travel subset
# (`this_data2`). The outcome variable is also appended to both covariate
# formulas. Presumably it enters through its base-period value; confirm in
# two_by_two_subset / the attgt function.
travel_res <- pte(
  yname = "workplaces_percent_change_from_baseline",
  gname = "group",
  tname = "time.period",
  idname = "fips",
  data = this_data2,
  subset_fun = two_by_two_subset,
  setup_pte_fun = setup_pte_basic,
  attgt_fun = attgt_fun,
  cband = FALSE,
  alp = 0.1,
  weighting_method = weighting_method,
  ret_imputation = TRUE,
  matching_formula = BMisc::addCovToFormla(
    "workplaces_percent_change_from_baseline", matching_formula
  ),
  xformla = BMisc::addCovToFormla(
    "workplaces_percent_change_from_baseline", xformla
  )
)

make_pandemic_plot(
  travel_res,
  ylabel = "Percentage Change in Trips to Work",
  xlabel = ""
)








